use crate::{
    ir::{Node, NodeType},
    node::{
        argmax::argmax_update_outputs, argmin::argmin_update_outputs,
        attention::attention_update_output, bernoulli::bernoulli_update_output,
        cast::cast_update_outputs, comparison::elementwise_comparison_outputs,
        concat::concat_update_outputs, constant::constant_update_outputs,
        constant_of_shape::constant_of_shape_update_output,
        depth_to_space::depth_to_space_update_outputs, expand::expand_update_outputs,
        flatten::flatten_update_outputs, gather::gather_update_outputs, gemm::gemm_output_shape,
        linear::linear_update_outputs, matmul::matmul_update_outputs,
        one_hot::one_hot_output_shape, random::random_update_output,
        random_like::random_like_update_output, range::range_update_outputs,
        reduce::reduce_update_outputs, reshape::reshape_update_outputs,
        shape::shape_update_outputs, size::size_update_outputs, slice::slice_update_output_rank,
        space_to_depth::space_to_depth_update_outputs, split::split_update_outputs,
        squeeze::squeeze_update_output, topk::top_k_update_output,
        unsqueeze::unsqueeze_update_output, where_op::where_update_outputs,
    },
    util::{same_as_input, same_as_input_broadcast, temporary_pass_through_stub},
};

/// Infer the rank of each output tensor and update them based solely on rank inference.
pub fn rank_inference(node: &mut Node) {
    log::debug!("Inferring rank for node: {}", node.name);

    match node.node_type {
        NodeType::Add => same_as_input_broadcast(node),
        NodeType::ArgMax => argmax_update_outputs(node),
        NodeType::ArgMin => argmin_update_outputs(node),
        NodeType::Attention => attention_update_output(node),
        NodeType::AveragePool1d => same_as_input(node),
        NodeType::AveragePool2d => same_as_input(node),
        NodeType::BatchNormalization => same_as_input(node),
        NodeType::BitShift => same_as_input_broadcast(node),
        NodeType::BitwiseAnd => same_as_input_broadcast(node),
        NodeType::BitwiseNot => same_as_input(node),
        NodeType::BitwiseOr => same_as_input_broadcast(node),
        NodeType::BitwiseXor => same_as_input_broadcast(node),
        NodeType::Bernoulli => bernoulli_update_output(node),
        NodeType::Cast => cast_update_outputs(node),
        NodeType::Ceil => same_as_input(node),
        NodeType::Clip => same_as_input(node),
        NodeType::Concat => concat_update_outputs(node),
        NodeType::Constant => constant_update_outputs(node),
        NodeType::ConstantOfShape => constant_of_shape_update_output(node),
        NodeType::Conv1d => same_as_input(node),
        NodeType::Conv2d => same_as_input(node),
        NodeType::Cos => same_as_input(node),
        NodeType::Cosh => same_as_input(node),
        NodeType::Div => same_as_input_broadcast(node),
        NodeType::Dropout => same_as_input(node),
        NodeType::Equal => elementwise_comparison_outputs(node),
        NodeType::Erf => same_as_input(node),
        NodeType::Exp => same_as_input(node),
        NodeType::Expand => expand_update_outputs(node),
        NodeType::Floor => same_as_input(node),
        NodeType::Flatten => flatten_update_outputs(node),
        NodeType::Gelu => same_as_input(node),
        NodeType::Gather => gather_update_outputs(node),
        NodeType::GatherElements => same_as_input(node),
        NodeType::Gemm => gemm_output_shape(node),
        NodeType::Greater => elementwise_comparison_outputs(node),
        NodeType::GreaterOrEqual => elementwise_comparison_outputs(node),
        NodeType::HardSigmoid => same_as_input(node),
        NodeType::GlobalAveragePool => same_as_input(node),
        NodeType::ConvTranspose1d => same_as_input(node),
        NodeType::ConvTranspose2d => same_as_input(node),
        NodeType::InstanceNormalization => same_as_input(node),
        NodeType::IsInf => elementwise_comparison_outputs(node),
        NodeType::IsNaN => elementwise_comparison_outputs(node),
        NodeType::LayerNormalization => same_as_input(node),
        NodeType::GroupNormalization => same_as_input(node),
        NodeType::DepthToSpace => depth_to_space_update_outputs(node),
        NodeType::LeakyRelu => same_as_input(node),
        NodeType::Less => elementwise_comparison_outputs(node),
        NodeType::LessOrEqual => elementwise_comparison_outputs(node),
        NodeType::Linear => linear_update_outputs(node),
        NodeType::Log => same_as_input(node),
        NodeType::LogSoftmax => same_as_input(node),
        NodeType::MatMul => matmul_update_outputs(node),
        NodeType::Max => same_as_input_broadcast(node),
        NodeType::MaxPool1d => same_as_input(node),
        NodeType::MaxPool2d => same_as_input(node),
        NodeType::Min => same_as_input_broadcast(node),
        NodeType::Mul => same_as_input_broadcast(node),
        NodeType::Neg => same_as_input(node),
        NodeType::Not => same_as_input(node),
        NodeType::And => same_as_input(node),
        NodeType::Or => same_as_input(node),
        NodeType::Xor => same_as_input(node),
        NodeType::OneHot => one_hot_output_shape(node),
        NodeType::Pad => same_as_input(node),
        NodeType::PRelu => same_as_input_broadcast(node),
        NodeType::Pow => same_as_input_broadcast(node),
        NodeType::RandomNormal => random_update_output(node),
        NodeType::RandomNormalLike => random_like_update_output(node),
        NodeType::RandomUniform => random_update_output(node),
        NodeType::RandomUniformLike => random_like_update_output(node),
        NodeType::Range => range_update_outputs(node),
        NodeType::Reciprocal => same_as_input(node),
        NodeType::ReduceMax => reduce_update_outputs(node),
        NodeType::ReduceMin => reduce_update_outputs(node),
        NodeType::ReduceMean => reduce_update_outputs(node),
        NodeType::ReduceProd => reduce_update_outputs(node),
        NodeType::ReduceSum => reduce_update_outputs(node),
        NodeType::ReduceSumSquare => reduce_update_outputs(node),
        NodeType::ReduceL1 => reduce_update_outputs(node),
        NodeType::ReduceL2 => reduce_update_outputs(node),
        NodeType::ReduceLogSum => reduce_update_outputs(node),
        NodeType::ReduceLogSumExp => reduce_update_outputs(node),
        NodeType::Relu => same_as_input(node),
        NodeType::Reshape => reshape_update_outputs(node),
        NodeType::Resize => same_as_input(node),
        NodeType::Round => same_as_input(node),
        NodeType::Shape => shape_update_outputs(node),
        NodeType::Sigmoid => same_as_input(node),
        NodeType::Sign => same_as_input(node),
        NodeType::Sin => same_as_input(node),
        NodeType::Sinh => same_as_input(node),
        NodeType::Size => size_update_outputs(node),
        NodeType::Slice => slice_update_output_rank(node),
        NodeType::Softmax => same_as_input(node),
        NodeType::SpaceToDepth => space_to_depth_update_outputs(node),
        NodeType::Split => split_update_outputs(node),
        NodeType::Squeeze => squeeze_update_output(node),
        NodeType::Sqrt => same_as_input(node),
        NodeType::Sub => same_as_input_broadcast(node),
        NodeType::Sum => same_as_input_broadcast(node),
        NodeType::Tan => same_as_input(node),
        NodeType::Tanh => same_as_input(node),
        NodeType::TopK => top_k_update_output(node),
        NodeType::Transpose => same_as_input(node),
        NodeType::Trilu => same_as_input(node),
        NodeType::Unsqueeze => unsqueeze_update_output(node),
        NodeType::Where => where_update_outputs(node),
        _ => temporary_pass_through_stub(node),
    }

    log::debug!(
        "Rank inference result for {}: {:?}",
        node.name,
        node.outputs
    );
}
